#############################################################################################
# NKPH estimation methods
#############################################################################################


using NLopt
using DataFrames
using XLSX
using CSV
using Dates
using Statistics
using Random

# Default directory locations (relative paths):
#   MOMENT_DIR - data used to build empirical moment targets
#   MODEL_DIR  - model specification files
#   ESTIM_DIR  - saved estimates
# NOTE(review): none of these globals is referenced in the visible part of this file;
# presumably they are consumed further down or by scripts that include it -- confirm.
MOMENT_DIR = "../targets/"
MODEL_DIR = "../models/"
ESTIM_DIR = "../models/estimates/"



#############################################################################################
# simple statistical functions to compute moments

"""
    _calc_fwd_diff(x, D_diff)

Forward difference of `x` at horizon `D_diff`: element `t` of the result is
`x[t + D_diff] - x[t]`. The result has the same length as `x`; the trailing
`D_diff` entries, which have no forward observation, are `missing`.
"""
function _calc_fwd_diff(x::AbstractVector, D_diff::Int)
    n_obs = length(x)
    # start from an all-missing vector and overwrite the leading block
    fwd = Vector{Union{Missing, Float64}}(missing, n_obs)
    n_valid = n_obs - D_diff
    fwd[1:n_valid] = x[D_diff+1:end] - x[1:n_valid]
    return fwd
end

"""
    _get_moment_data(var_label, df)

Return the column of `df` whose name is `var_label` (given as `String` or `Symbol`).
"""
_get_moment_data(var_label::Union{String, Symbol}, df::DataFrame) =
    df[:, Symbol(var_label)]
"""
    _calc_cov_missing(x, y)

Sample covariance of `x` and `y`, using only the positions where neither
vector is `missing`.
"""
function _calc_cov_missing(x::AbstractVector, y::AbstractVector)
    # a pair is usable only when both observations are present
    both_present = .!(ismissing.(x) .| ismissing.(y))
    return cov(x[both_present], y[both_present])
end
"""
    _calc_sd_missing(x)

Sample standard deviation of `x` with `missing` entries dropped, computed as
the square root of the covariance of `x` with itself.
"""
_calc_sd_missing(x::AbstractVector) = sqrt(_calc_cov_missing(x, x))
"""
    _calc_corr_missing(x, y)

Correlation of `x` and `y`, using only the positions where neither vector is
`missing`.
"""
function _calc_corr_missing(x::AbstractVector, y::AbstractVector)
    # a pair is usable only when both observations are present
    both_present = .!(ismissing.(x) .| ismissing.(y))
    return cor(x[both_present], y[both_present])
end
"""
    _calc_coeff_missing(x, y)

Univariate regression slope `cov(x, y) / var(y)`, using only the positions
where neither vector is `missing`.
"""
function _calc_coeff_missing(x::AbstractVector, y::AbstractVector)
    both_present = .!(ismissing.(x) .| ismissing.(y))
    x_obs = x[both_present]
    y_obs = y[both_present]
    # cov of a single vector is its sample variance
    return cov(x_obs, y_obs) / cov(y_obs)
end


#############################################################################################

"""
    NKPHEmpiricalMoments(rates_fpath, rates_date_start, rates_date_end,
                         regs_fpath, alt_regs_fpath,
                         tau_arr, tau_short, tau_long, s_auto_diffs)

Empirical data used to build moment targets.

The constructor reads three CSV files:

- `rates_fpath`: a rates panel. Its first column is treated as the time index
  and skipped when differencing; a column `ym` is matched against
  `rates_date_start` / `rates_date_end`; columns `y_<m>` and `ytil_<m>` are
  looked up with `m = 12 * tau`.
- `regs_fpath`, `alt_regs_fpath`: regression tables with a `tau` column; only
  rows whose `tau` is in `tau_arr` are kept.

Columns added to the rates panel before the date filter is applied:

- `ishort`, `ishorttil`, `ilong`, `ilongtil`: copies of the `tau_short` /
  `tau_long` columns,
- `y_slope_<m>`, `ytil_slope_<m>`: the `y_<m>` / `ytil_<m>` column minus `ishort`,
- `y_diff_<m>`, `ishort_diff`, `ilong_diff`: `ytil_<m>` minus `y_<m>`,
- `D<k>_<name>`: forward difference at `k = 12 * s` months for every `s` in
  `s_auto_diffs` and every column except the first.

# Throws
- `ArgumentError` if either sample date is not found in column `ym`.
- `InexactError` if `12 * tau` or `12 * s` is not an integer.
"""
mutable struct NKPHEmpiricalMoments
    # inputs
    rates_fpath::String
    regs_fpath::String
    alt_regs_fpath::String
    rates_date_start::String
    rates_date_end::String

    # dataframes
    df_rates::DataFrame
    df_regs::DataFrame
    df_alt_regs::DataFrame

    # short/long maturities
    tau_short::Float64
    tau_long::Float64

    # sample indexes (first/last row of the unfiltered rates panel that is kept)
    # note: indexes are applied after differencing
    # eg, D12_x_T = x_{T+12} - x_T is included

    rates_sample_idxs::Tuple{Int, Int}

    function NKPHEmpiricalMoments(
            rates_fpath::String, rates_date_start::String, rates_date_end::String,
            regs_fpath::String, alt_regs_fpath::String,
            tau_arr::Vector{Float64},
            tau_short::Float64, tau_long::Float64,
            s_auto_diffs::Vector{Float64}
        )
        nkphemp = new(rates_fpath, regs_fpath, alt_regs_fpath, rates_date_start, rates_date_end)

        # maturity in years -> monthly label used in the column names
        month_label(tau) = string(Int(tau*12))

        # read rates data
        df = DataFrame(CSV.File(rates_fpath))
        nkphemp.df_rates = df

        # construct/redefine short and long maturity rates
        nkphemp.tau_short = tau_short
        nkphemp.tau_long = tau_long
        for (tau, rate_sym, rate_til_sym) in (
                (tau_short, :ishort, :ishorttil),
                (tau_long, :ilong, :ilongtil),
            )
            yld = df[:, Symbol("y_" * month_label(tau))]
            yld_til = df[:, Symbol("ytil_" * month_label(tau))]
            df[:, rate_sym] = yld
            df[:, rate_til_sym] = yld_til
        end

        # construct additional yield curve objects
        for tau in tau_arr
            m = month_label(tau)
            yld_sym = Symbol("y_" * m)
            yld_til_sym = Symbol("ytil_" * m)
            has_yld = columnindex(df, yld_sym) > 0
            # FIX: test for the ytil_ column itself; the y_ column was checked twice,
            # so a panel with y_<m> but no ytil_<m> failed on the lookup below
            has_yld_til = columnindex(df, yld_til_sym) > 0

            # slopes (relative to the short rate)
            if has_yld
                df[:, Symbol("y_slope_" * m)] = df[:, yld_sym] - df[:, :ishort]
            end
            if has_yld_til
                df[:, Symbol("ytil_slope_" * m)] = df[:, yld_til_sym] - df[:, :ishort]
            end
            # spreads
            if has_yld && has_yld_til
                yld_gap = df[:, yld_til_sym] - df[:, yld_sym]
                df[:, Symbol("y_diff_" * m)] = yld_gap

                # additionally specify short/long diffs
                if tau == tau_short
                    df[:, :ishort_diff] = yld_gap
                elseif tau == tau_long
                    df[:, :ilong_diff] = yld_gap
                end
            end
        end

        # construct/redefine short and long maturity spreads
        # (also covers the case where tau_short/tau_long are not in tau_arr)
        for (tau, diff_sym) in ((tau_short, :ishort_diff), (tau_long, :ilong_diff))
            yld = df[:, Symbol("y_" * month_label(tau))]
            yld_til = df[:, Symbol("ytil_" * month_label(tau))]
            df[:, diff_sym] = yld_til - yld
        end

        # differences (all variables; drop time index)
        # the name list is fixed before the loop, so new D-columns are not re-differenced
        varnames = String.(names(df))
        for varname in varnames[2:end]
            for s_auto in s_auto_diffs
                s_auto_m = Int(s_auto*12)
                varname_diff_sym = Symbol("D" * string(s_auto_m) * "_" * varname)
                df[:, varname_diff_sym] = _calc_fwd_diff(df[:, Symbol(varname)], s_auto_m)
            end
        end

        # keep specified dates (after differencing)
        idx_start = findfirst(df[:, :ym] .== rates_date_start)
        idx_end = findfirst(df[:, :ym] .== rates_date_end)
        if idx_start === nothing
            throw(ArgumentError(
                "rates_date_start \"$(rates_date_start)\" not found in column ym of $(rates_fpath)"))
        end
        if idx_end === nothing
            throw(ArgumentError(
                "rates_date_end \"$(rates_date_end)\" not found in column ym of $(rates_fpath)"))
        end
        # record the sample window; the field was previously left unassigned
        nkphemp.rates_sample_idxs = (idx_start, idx_end)
        nkphemp.df_rates = df[idx_start:idx_end, :]

        # read regression data, keeping the specified maturities
        function read_regs(fpath)
            df_all = DataFrame(CSV.File(fpath))
            tau_regs = Vector{Float64}(df_all[:, :tau])
            return df_all[in.(tau_regs, Ref(tau_arr)), :]
        end
        nkphemp.df_regs = read_regs(regs_fpath)
        nkphemp.df_alt_regs = read_regs(alt_regs_fpath)

        return nkphemp
    end
end

# using data labels and dataframes
"""
    _calc_cov_missing(var_label_x, var_label_y, nkphemp)

Covariance of the two columns of `nkphemp.df_rates` named by the labels,
jointly dropping missing values.
"""
function _calc_cov_missing(var_label_x::Union{String, Symbol}, var_label_y::Union{String, Symbol},
        nkphemp::NKPHEmpiricalMoments)
    rates = nkphemp.df_rates
    return _calc_cov_missing(
        _get_moment_data(var_label_x, rates),
        _get_moment_data(var_label_y, rates),
    )
end
"""
    _calc_sd_missing(var_label_x, nkphemp)

Standard deviation of the column of `nkphemp.df_rates` named by the label,
dropping missing values.
"""
function _calc_sd_missing(var_label_x::Union{String, Symbol},
        nkphemp::NKPHEmpiricalMoments)
    return _calc_sd_missing(_get_moment_data(var_label_x, nkphemp.df_rates))
end
"""
    _calc_corr_missing(var_label_x, var_label_y, nkphemp)

Correlation of the two columns of `nkphemp.df_rates` named by the labels,
jointly dropping missing values.
"""
function _calc_corr_missing(var_label_x::Union{String, Symbol}, var_label_y::Union{String, Symbol},
        nkphemp::NKPHEmpiricalMoments)
    rates = nkphemp.df_rates
    return _calc_corr_missing(
        _get_moment_data(var_label_x, rates),
        _get_moment_data(var_label_y, rates),
    )
end
"""
    _calc_coeff_missing(var_label_x, var_label_y, nkphemp)

Regression slope `cov(x, y) / var(y)` for the two columns of
`nkphemp.df_rates` named by the labels, jointly dropping missing values.
"""
function _calc_coeff_missing(var_label_x::Union{String, Symbol}, var_label_y::Union{String, Symbol},
        nkphemp::NKPHEmpiricalMoments)
    rates = nkphemp.df_rates
    return _calc_coeff_missing(
        _get_moment_data(var_label_x, rates),
        _get_moment_data(var_label_y, rates),
    )
end
"""
    _get_cond_coeff(coeff_type, tau, nkphemp)

Look up a pre-computed regression coefficient for maturity `tau`.

`coeff_type` selects both the table and the column: `:localization_short` and
`:localization_long` are read from `nkphemp.df_regs`, `:alt_localization_short`
and `:alt_localization_long` from `nkphemp.df_alt_regs`. The row is the first
one whose `tau` column equals `tau`.

# Throws
- `DomainError` for an unknown `coeff_type`, or when `tau` is not present in
  the selected table.
"""
function _get_cond_coeff(coeff_type::Symbol, tau::Float64,
        nkphemp::NKPHEmpiricalMoments)
    if coeff_type in (:localization_short, :localization_long)
        # baseline localization
        df_cond_coeffs = nkphemp.df_regs
    elseif coeff_type in (:alt_localization_short, :alt_localization_long)
        # alternative localization
        df_cond_coeffs = nkphemp.df_alt_regs
    else
        throw(DomainError(coeff_type, "bad cond coeff coeff_type"))
    end
    # FIX: search the table that was actually selected; the row index was always
    # taken from df_regs, which is wrong whenever the two tables are not row-aligned
    idx_tau = findfirst(df_cond_coeffs[:, :tau] .== tau)
    if idx_tau === nothing
        throw(DomainError(tau, "maturity not found in regression table for $(coeff_type)"))
    end
    return df_cond_coeffs[idx_tau, coeff_type]
end

# wrapper
"""
    get_target_val(var_label_x, var_label_y, cov_type, nkphemp)

Empirical second-moment target computed from columns of `nkphemp.df_rates`.

`cov_type` selects the statistic:

- `:cov`   -- covariance of `x` and `y`
- `:sd`    -- standard deviation of `x` (requires `var_label_x` and
  `var_label_y` to name the same column)
- `:corr`  -- correlation of `x` and `y`
- `:coeff` -- regression slope `cov(x, y) / var(y)`

Missing values are dropped (jointly for the two-variable statistics).

# Throws
- `DomainError` for an unknown `cov_type`, or for `:sd` with different labels.
"""
function get_target_val(
        var_label_x::Union{String, Symbol},
        var_label_y::Union{String, Symbol},
        cov_type::Symbol,
        nkphemp::NKPHEmpiricalMoments)
    if cov_type == :cov
        val = _calc_cov_missing(var_label_x, var_label_y, nkphemp)
    elseif cov_type == :sd
        # compare as Symbols so that "x" and :x count as the same label
        if Symbol(var_label_x) != Symbol(var_label_y)
            throw(DomainError((var_label_x, var_label_y), "sd specified but x =/= y"))
        end
        val = _calc_sd_missing(var_label_x, nkphemp)
    elseif cov_type == :corr
        val = _calc_corr_missing(var_label_x, var_label_y, nkphemp)
    elseif cov_type == :coeff
        val = _calc_coeff_missing(var_label_x, var_label_y, nkphemp)
    else
        # DomainError takes (value, message) in that order
        throw(DomainError(cov_type, "bad scalar moment cov_type"))
    end
    return val
end
"""
    get_target_val(coeff_type, tau, nkphemp)

Regression-coefficient target for maturity `tau`: converts `coeff_type` to a
`Symbol` and returns the matching entry found by `_get_cond_coeff`.
"""
function get_target_val(coeff_type::Union{String, Symbol}, tau::Float64,
        nkphemp::NKPHEmpiricalMoments)
    return _get_cond_coeff(Symbol(coeff_type), tau, nkphemp)
end


#############################################################################################
# estimation object

"""
    _xls2dataframe(xls_fpath, xls_sheet)

Read worksheet `xls_sheet` of the workbook at `xls_fpath` into a `DataFrame`.
"""
function _xls2dataframe(xls_fpath, xls_sheet)
    # NOTE(review): splatting assumes readtable returns a (data, column_labels)
    # pair, as in older XLSX.jl releases -- confirm against the pinned version
    sheet_table = XLSX.readtable(xls_fpath, xls_sheet)
    return DataFrame(sheet_table...)
end

"""
    NKPHEstimationOptimizer(model_fpath, rates_fpath, regs_fpath, alt_regs_fpath,
                            rates_date_start, rates_date_end)

Top-level estimation object: holds the file paths, the option tables read from
the model workbook, the maturity grids, and (once initialised) the empirical
moments, model objects and NLopt optimizer.

The constructor only stores the paths, reads the workbook sheets `est_opts`,
`soln_opts`, `state_vars`, `short_rate_mat`, `maturity_opts`, `params_fixed`,
`params_constraints` and `target_covs`, and extracts the maturity/time grids
from `maturity_opts`. Every other field is left unassigned here; the steps that
fill them (`initialize_empirical_moments!`, `initialize_model!`, `setup_nlopt!`)
are not called by the constructor.
"""
mutable struct NKPHEstimationOptimizer
    model_fpath::String
    rates_fpath::String
    regs_fpath::String
    alt_regs_fpath::String
    rates_date_start::String
    rates_date_end::String
    model_name::String
    # empirical moments object
    nkphemp::NKPHEmpiricalMoments

    # NLopt optimization object
    nl_opt::Opt

    # underlying model objects
    nkphtgts::NKPHTargets
    nkphmmts::NKPHMoments
    nkphsolvr::NKPHSolver

    # array of maturities/time diffs
    tau_arr::Vector{Float64}
    tau_arr_long::Vector{Float64}
    tau_arr_til::Vector{Float64}
    tau_arr_reg::Vector{Float64}
    s_auto_arr::Vector{Float64}
    s_auto_diffs_arr::Vector{Float64}
    tau_short::Float64
    tau_short_idx::Int
    tau_long::Float64
    tau_long_idx::Int

    # array of moment variables/targets
    mmtvars_arr::Vector{MomentVariable}
    mmtloss_arr::Vector{MomentLoss}
    mmtvar_tau_lookup::Dict{Tuple{Symbol, Float64}, Int}
    mmtloss_group_lookup::Dict{Int, Vector{Int}}

    # model object dataframes
    df_est_opts::DataFrame
    df_soln_opts::DataFrame
    df_state_vars::DataFrame
    df_short_rate_mat::DataFrame
    df_maturity_opts::DataFrame
    df_params_fixed::DataFrame
    df_params_constraints::DataFrame
    df_target_covs::DataFrame

    # display options
    display::Bool
    N_display::Int

    # current best results
    minf_curr::Float64
    minx_curr::Vector{Float64}
    constraints_curr::Vector{Float64}

    function NKPHEstimationOptimizer(model_fpath::String,
            rates_fpath::String, regs_fpath::String, alt_regs_fpath::String,
            rates_date_start::String, rates_date_end::String,
        )
        nkphopt = new(model_fpath)
        # TODO optional input: maturity grid is fixed at 1, 2, ..., 30
        nkphopt.tau_arr = collect(1.0:1.0:30.0)

        # moment files and sample window
        nkphopt.rates_fpath = rates_fpath
        nkphopt.regs_fpath = regs_fpath
        nkphopt.alt_regs_fpath = alt_regs_fpath
        nkphopt.rates_date_start = rates_date_start
        nkphopt.rates_date_end = rates_date_end

        # each sheet "<name>" of the model workbook is stored in field df_<name>
        for sheet in ("est_opts", "soln_opts", "state_vars", "short_rate_mat",
                "maturity_opts", "params_fixed", "params_constraints", "target_covs")
            setproperty!(nkphopt, Symbol("df_", sheet), _xls2dataframe(model_fpath, sheet))
        end

        # non-missing entries of a maturity_opts column, as Float64
        function opts_column(col::Symbol)
            raw = nkphopt.df_maturity_opts[:, col]
            return Vector{Float64}(raw[ .! ismissing.(raw)])
        end

        # maturity grids
        nkphopt.tau_arr_long = opts_column(:tau)
        nkphopt.tau_arr_til = opts_column(:tau_til)
        nkphopt.tau_arr_reg = opts_column(:tau_reg)

        # short/long maturities are taken from the first row
        nkphopt.tau_short = Float64(nkphopt.df_maturity_opts[1, :tau_short])
        nkphopt.tau_long = Float64(nkphopt.df_maturity_opts[1, :tau_long])

        # time diffs
        nkphopt.s_auto_arr = opts_column(:s_auto)

        return nkphopt
    end
end

"""
    initialize_empirical_moments!(nkphopt)

Set up the empirical side of the estimation, in place:

1. locate `tau_short` and `tau_long` inside `tau_arr`
   (`tau_short_idx`, `tau_long_idx`),
2. build `s_auto_diffs_arr`: `0.0`, the entries of `s_auto_arr`, and then every
   absolute pairwise difference among those that is not already present,
3. construct `nkphemp` from the rates and regression files.

Returns `nothing`.
"""
function initialize_empirical_moments!(nkphopt::NKPHEstimationOptimizer)::Nothing
    # positions of the short/long maturities in the maturity grid
    nkphopt.tau_short_idx = findfirst(t -> t == nkphopt.tau_short, nkphopt.tau_arr)
    nkphopt.tau_long_idx = findfirst(t -> t == nkphopt.tau_long, nkphopt.tau_arr)

    # time diffs: 0.0 first, then the configured horizons
    s_diffs = vcat(0.0, nkphopt.s_auto_arr)
    nkphopt.s_auto_diffs_arr = s_diffs
    # append pairwise gaps; only the initial entries are paired up, so the
    # loop bounds are fixed before anything is pushed
    n_base = length(s_diffs)
    for i in 1:n_base, j in i:n_base
        gap = abs(s_diffs[j] - s_diffs[i])
        gap in s_diffs || push!(s_diffs, gap)
    end

    # empirical moments
    # NOTE(review): the differencing horizons passed here are s_auto_arr, not the
    # extended s_auto_diffs_arr built above -- confirm this is intended
    nkphopt.nkphemp = NKPHEmpiricalMoments(
        nkphopt.rates_fpath, nkphopt.rates_date_start, nkphopt.rates_date_end,
        nkphopt.regs_fpath, nkphopt.alt_regs_fpath,
        nkphopt.tau_arr,
        nkphopt.tau_short, nkphopt.tau_long, nkphopt.s_auto_arr
    )
    return nothing
end



# wrapper to get optimization options from DataFrame
"""
    _get_solver_options(opt_name, nkphopt)

Return the `value` entry of the first row of `nkphopt.df_soln_opts` whose
`continuation_opts` entry equals `opt_name`.

# Throws
- `ArgumentError` if no row matches.
"""
function _get_solver_options(opt_name::String, nkphopt::NKPHEstimationOptimizer)
    # isequal treats blank (missing) cells as non-matching instead of propagating them
    idx = findfirst(isequal(opt_name), nkphopt.df_soln_opts[:, :continuation_opts])
    if idx === nothing
        throw(ArgumentError("solver option \"$(opt_name)\" not found in df_soln_opts"))
    end
    return nkphopt.df_soln_opts[idx, :value]
end
"""
    _get_optimization_options(opt_name, nkphopt)

Return the `value` entry of the first row of `nkphopt.df_est_opts` whose
`optimization_opts` entry equals `opt_name`. The value is returned as stored,
so it may be `missing`.

# Throws
- `ArgumentError` if no row matches.
"""
function _get_optimization_options(opt_name::String, nkphopt::NKPHEstimationOptimizer)
    # isequal treats blank (missing) cells as non-matching instead of propagating them
    idx = findfirst(isequal(opt_name), nkphopt.df_est_opts[:, :optimization_opts])
    if idx === nothing
        throw(ArgumentError("optimization option \"$(opt_name)\" not found in df_est_opts"))
    end
    return nkphopt.df_est_opts[idx, :value]
end
"""
    initialize_model!(nkphopt)

Create the model-side objects and store them on `nkphopt`:

- `model_name`, read from the optimization options,
- `nkphsolvr`, built by the `initialize_nkphsolvr_*` function matching
  `model_name` (`"semiC"`, `"semiC_QE"`, `"semiC_MP"`, `"semiC_QE_passive"`),
- `nkphmmts`, the moment object on the maturity grid `tau_arr` and the time
  differences `s_auto_diffs_arr`,
- `mmtvars_arr`, one `MomentVariable` per (variable, maturity) implied by the
  `target_covs` table, and `mmtvar_tau_lookup`, mapping
  `(variable_name, tau)` to its position in `mmtvars_arr`.

`nkphopt.s_auto_diffs_arr` must already be set; it is assigned by
`initialize_empirical_moments!`.

Returns `nothing`.

# Throws
- `DomainError` if `model_name` is not one of the supported names.
"""
function initialize_model!(nkphopt::NKPHEstimationOptimizer)
    # model name and tracking target
    nkphopt.model_name = String(_get_optimization_options("model_name", nkphopt))
    a = Float64(_get_solver_options("tracking_target", nkphopt))

    # model setup (depends on model_name)
    # NOTE: full, semiA, semiB no longer supported
    model_name = nkphopt.model_name
    init_solver = if model_name == "semiC"
        initialize_nkphsolvr_semiC
    elseif model_name == "semiC_QE"
        initialize_nkphsolvr_semiC_QE
    elseif model_name == "semiC_MP"
        initialize_nkphsolvr_semiC_MP
    elseif model_name == "semiC_QE_passive"
        initialize_nkphsolvr_semiC_QE_passive
    else
        throw(DomainError(model_name, "unsupported model_name"))
    end
    # all variants take the same inputs
    nkphsolvr = init_solver(a,
        nkphopt.df_short_rate_mat,
        nkphopt.df_params_constraints,
        nkphopt.df_params_fixed,
    )

    # moment object
    nkphmmts = NKPHMoments(nkphopt.tau_short, nkphopt.tau_long,
        nkphopt.tau_arr, nkphopt.s_auto_diffs_arr, nkphsolvr)

    # unique (variable, moment type) pairs over both sides of every target
    cov_tbl = nkphopt.df_target_covs
    pairs_i = collect(zip(Vector{String}(cov_tbl[:, :var_i]),
        Vector{String}(cov_tbl[:, :moment_type_i])))
    pairs_j = collect(zip(Vector{String}(cov_tbl[:, :var_j]),
        Vector{String}(cov_tbl[:, :moment_type_j])))
    varnames_moment_type_all = unique(vcat(pairs_i, pairs_j))

    # maturities at which each moment type is evaluated; scalars use tau = 0.0
    taus_by_type = Dict{String, Vector{Float64}}(
        "scalar" => [0.0],
        "tau" => nkphopt.tau_arr_long,
        "tau_til" => nkphopt.tau_arr_til,
        "tau_reg" => nkphopt.tau_arr_reg,
    )
    mmtvars_arr = Vector{MomentVariable}(undef, 0)
    for (varname, moment_type) in varnames_moment_type_all
        # any other moment type contributes no variables here
        for tau in get(taus_by_type, moment_type, Float64[])
            push!(mmtvars_arr, MomentVariable(Symbol(varname), tau, nkphmmts))
        end
    end

    # for mmtvar lookups (a repeated key keeps the last position)
    mmtvar_tau_lookup = Dict{Tuple{Symbol, Float64}, Int}()
    for (idx, mmtvar) in enumerate(mmtvars_arr)
        mmtvar_tau_lookup[(mmtvar.variable_name, mmtvar.tau)] = idx
    end

    # assign objects
    nkphopt.nkphsolvr = nkphsolvr
    nkphopt.nkphmmts = nkphmmts
    nkphopt.mmtvars_arr = mmtvars_arr
    nkphopt.mmtvar_tau_lookup = mmtvar_tau_lookup
    return nothing
end


"""
    initialize_targets!(nkphopt)

Build one `MomentLoss` per target moment and store the results on `nkphopt`
(`mmtloss_arr`, `mmtloss_group_lookup`, `nkphtgts`).

Each row of `df_target_covs` is a target group; its first eight columns are
read positionally as
`var_i, var_j, moment_type_i, moment_type_j, cov_type, idx_condcoeff, coeff_type, wgt`.

- scalar/scalar rows yield a single loss with weight `wgt`;
- rows whose `moment_type_i` is `"tau"`, `"tau_til"` or `"tau_reg"` yield one
  loss per maturity of the matching grid, each with weight
  `wgt / (number of maturities)`. The empirical target is a second moment of the
  columns `<var>_<12*tau>` when `coeff_type` is missing, and the regression
  coefficient `coeff_type` at `tau` otherwise.

`mmtloss_group_lookup` maps the row index of a group to the positions of its
losses in `mmtloss_arr`. Returns `nothing`.

# Throws
- `DomainError` for an unsupported combination of moment types.
"""
function initialize_targets!(nkphopt::NKPHEstimationOptimizer)
    # mapping loss targets to mmtvars
    lookup = nkphopt.mmtvar_tau_lookup
    mmtvars = nkphopt.mmtvars_arr

    losses = MomentLoss[]
    group_lookup = Dict{Int, Vector{Int}}()
    n_losses = 0    # running count of losses created so far

    for idx_group in 1:size(nkphopt.df_target_covs, 1)
        # get moment options (positional: relies on the column order of the sheet)
        var_i, var_j, moment_type_i, moment_type_j, cov_type, idx_condcoeff, coeff_type, wgt =
            nkphopt.df_target_covs[idx_group, :]
        var_i_sym, var_j_sym, cov_type_sym = Symbol(var_i), Symbol(var_j), Symbol(cov_type)

        if moment_type_i == "scalar" && moment_type_j == "scalar"
            # scalar moment: empirical target, then the two model variables
            bhat = get_target_val(var_i, var_j, cov_type_sym, nkphopt.nkphemp)
            mmtvar_L = mmtvars[lookup[(var_i_sym, 0.0)]]
            mmtvar_R = mmtvars[lookup[(var_j_sym, 0.0)]]
            push!(losses,
                MomentLoss(cov_type_sym, wgt/1.0, bhat, idx_condcoeff, mmtvar_L, mmtvar_R))

            n_losses += 1
            group_lookup[idx_group] = [n_losses]
        elseif moment_type_i in ("tau", "tau_til", "tau_reg")
            # function of maturity: pick the grid for this moment type
            tau_grid = if moment_type_i == "tau"
                nkphopt.tau_arr_long
            elseif moment_type_i == "tau_til"
                nkphopt.tau_arr_til
            else
                nkphopt.tau_arr_reg
            end
            # weight scaled by number of maturities
            wgt_scale = wgt / length(tau_grid)
            group_idxs = Int[]
            for tau in tau_grid
                # empirical column name uses the monthly maturity
                var_i_tau = var_i * "_" * string(Int(12*tau))
                idx_L = lookup[(var_i_sym, tau)]
                if moment_type_j in ("tau", "tau_til", "tau_reg")
                    var_j_tau = var_j * "_" * string(Int(12*tau))
                    idx_R = lookup[(var_j_sym, tau)]
                elseif moment_type_j == "scalar"
                    var_j_tau = var_j
                    idx_R = lookup[(var_j_sym, 0.0)]
                else
                    throw(DomainError((moment_type_i, moment_type_j), "bad moments combo"))
                end
                # regression coeffs and unconditional second moments are looked up differently
                bhat = if ismissing(coeff_type)
                    get_target_val(var_i_tau, var_j_tau, cov_type_sym, nkphopt.nkphemp)
                else
                    get_target_val(coeff_type, tau, nkphopt.nkphemp)
                end
                push!(losses,
                    MomentLoss(cov_type_sym, wgt_scale, bhat, idx_condcoeff,
                        mmtvars[idx_L], mmtvars[idx_R]))

                n_losses += 1
                push!(group_idxs, n_losses)
            end
            # mapping from loss groups to individual loss objects
            group_lookup[idx_group] = group_idxs
        else
            throw(DomainError((moment_type_i, moment_type_j), "not supported moment type"))
        end
    end

    # create targets object and assign
    nkphopt.mmtloss_arr = losses
    nkphopt.mmtloss_group_lookup = group_lookup
    nkphopt.nkphtgts = NKPHTargets(nkphopt.mmtvars_arr, nkphopt.mmtloss_arr, nkphopt.nkphmmts)
    return nothing
end

"""
    setup_nlopt!(nkphopt)

Create the NLopt optimizer for the estimation problem and store it in
`nkphopt.nl_opt`. `nkphopt.nkphtgts` and `nkphopt.nkphsolvr` must already be set.

All settings come from the optimization options table:

- `algorithm` and the problem dimension `N_params_all` define the optimizer;
- `maxeval`, `xtol_rel`, `xtol_abs`, `ftol_rel`, `ftol_abs`, `stopval`,
  `maxtime`, `initial_step`, `population` and `nlopt_seed` are applied only
  when not `missing`;
- the first `N_params` entries are bounded by the `min` / `max` columns of
  `df_params_constraints`; the remaining entries are unbounded;
- `J^2 + J` equality constraints (model solution residuals) with tolerance
  `eq_constraint_tol`;
- two inequality constraints on the eigenvalues of Upsilon, plus a third on the
  smallest eigenvalue of M when `M_mineig_constraint` is true, with tolerance
  `ineq_constraint_tol`.

Returns the three callbacks `(objective, equality constraints, inequality
constraints)` that were registered with the optimizer.
"""
function setup_nlopt!(nkphopt::NKPHEstimationOptimizer)::Tuple{Function, Function, Function}
    # pull out underlying model objects; the callbacks below close over these
    nkphtgts = nkphopt.nkphtgts
    nkphsolvr = nkphopt.nkphsolvr
    J, N_params, N_params_all = nkphsolvr.J, nkphsolvr.N_params, nkphsolvr.N_params_all
    # display options (local name avoids shadowing Base.display)
    show_progress = _get_optimization_options("display", nkphopt)
    N_display = _get_optimization_options("N_display", nkphopt)

    # create NLopt object from algorithm and dimensionality of problem
    algorithm = _get_optimization_options("algorithm", nkphopt)
    nl_opt = Opt(Symbol(algorithm), N_params_all)

    # stopping criteria and step size: option name == NLopt property name
    for opt_sym in (:maxeval, :xtol_rel, :xtol_abs, :ftol_rel, :ftol_abs,
            :stopval, :maxtime, :initial_step)
        opt_val = _get_optimization_options(String(opt_sym), nkphopt)
        if !ismissing(opt_val)
            setproperty!(nl_opt, opt_sym, opt_val)
        end
    end
    population = _get_optimization_options("population", nkphopt)
    if !ismissing(population)
        nl_opt.population = Int(population)
    end
    # set seed for NLopt
    nlopt_seed = _get_optimization_options("nlopt_seed", nkphopt)
    if !ismissing(nlopt_seed)
        NLopt.srand(nlopt_seed)
    end

    # upper and lower bounds for params; AR_hat/M unconstrained
    _lb = fill(-Inf, N_params_all)
    _lb[1:N_params] = Vector{Float64}(nkphopt.df_params_constraints[:, :min])
    _ub = fill(Inf, N_params_all)
    _ub[1:N_params] = Vector{Float64}(nkphopt.df_params_constraints[:, :max])
    # small perturbation for params pinned to a single value (lb == ub) to avoid errors
    for idx=1:N_params
        if _lb[idx] == _ub[idx]
            _ub[idx] += 1e-10
        end
    end
    nl_opt.lower_bounds = _lb
    nl_opt.upper_bounds = _ub

    # eigenvalue inequality constraints
    ineq_constraint_tol = _get_optimization_options("ineq_constraint_tol", nkphopt)
    Upsilon_mineigval = _get_optimization_options("Upsilon_mineigval", nkphopt)
    M_mineigval = _get_optimization_options("M_mineigval", nkphopt)
    M_mineig_constraint = _get_optimization_options("M_mineig_constraint", nkphopt)

    # model solution
    eq_constraint_tol = _get_optimization_options("eq_constraint_tol", nkphopt)

    # assign optimization functions
    _count_display = 0
    # note: NLopt gradient object dimensions are (N_params_all x N_constraints)

    # objective: loss at x; fills loss_grad when NLopt requests a gradient
    function _min_objective(x::Vector{Float64}, loss_grad::Vector{Float64})::Float64
        if show_progress
            # print a progress dot every N_display evaluations
            _count_display += 1
            if mod(_count_display, N_display) == 0
                print("."); flush(stdout)
            end
        end
        update_nkphtgts!(x, nkphtgts)
        if length(loss_grad) > 0
            deriv_nkphtgts!(nkphtgts)
            loss_grad[:] = nkphtgts.dloss_val
        end
        return nkphtgts.loss_val[1]
    end

    # equality constraints: model solution residuals must be zero
    function _eq_constraints(
            root_vec::Vector{Float64},
            x::Vector{Float64},
            root_grad::Matrix{Float64}
        )::Nothing
        update_nkphsolvr!(x, nkphsolvr)
        root_vec[:] = nkphsolvr.root_vec
        if length(root_grad) > 0
            deriv_nkphsolvr!(nkphsolvr)
            # transpose into NLopt's (N_params_all x N_constraints) layout
            root_grad[:, :] = nkphsolvr.droot_vec'
        end
        return nothing
    end

    # inequality constraints on eigenvalues
    function _ineq_constraints(
            eigval_vec::Vector{Float64},
            x::Vector{Float64},
            eigval_grad::Matrix{Float64}
        )::Nothing
        # note: NLopt inequality constraints always of the form f(x) <= 0
        # switch sign for eigvals required to be positive
        update_nkphtgts!(x, nkphtgts)
        # Upsilon: check that satisfies both pos and neg eigval constraints
        # smallest positive eigval must be >= eps
        eigval_vec[1] = (Upsilon_mineigval - nkphtgts.Upsilon_min_pos_eigval[1])
        # largest negative eigval must be <= -eps
        eigval_vec[2] = (nkphtgts.Upsilon_max_neg_eigval[1] + Upsilon_mineigval)
        if M_mineig_constraint
            # smallest M eigenvalue must be >= eps
            eigval_vec[3] = (M_mineigval - nkphtgts.M_min_eigval[1])
        end
        if length(eigval_grad) > 0
            # FIX: differentiate the same object that was just updated; this call
            # previously went through nkphopt.nkphtgts, which can be a different
            # object if the targets are rebuilt after setup
            deriv_nkphtgts!(nkphtgts)
            # negative sign from NLopt inequality conventions
            eigval_grad[:, 1] = -nkphtgts.dUpsilon_min_pos_eigval
            # positive sign from NLopt inequality conventions
            eigval_grad[:, 2] = nkphtgts.dUpsilon_max_neg_eigval
            if M_mineig_constraint
                # negative sign from NLopt inequality conventions
                eigval_grad[:, 3] = -nkphtgts.dM_min_eigval
            end
        end
        return nothing
    end

    # assign to optimizer object
    nl_opt.min_objective = _min_objective
    _eq_constr_tol_arr = eq_constraint_tol .* ones(J^2 + J)
    equality_constraint!(nl_opt, _eq_constraints, _eq_constr_tol_arr)
    # two Upsilon constraints, plus one for M when enabled
    N_ineq = M_mineig_constraint ? 3 : 2
    _ineq_constr_tol_arr = ineq_constraint_tol .* ones(N_ineq)
    inequality_constraint!(nl_opt, _ineq_constraints, _ineq_constr_tol_arr)
    nkphopt.nl_opt = nl_opt
    return _min_objective, _eq_constraints, _ineq_constraints
end


"""
    _draw_init_params(rng, nkphopt, draw_init_method, draw_init_std)

Draw one random starting vector for the `N_params` model parameters.

- `draw_init_method == 1`: uniform draw between the optimizer's lower and upper
  bounds (`draw_init_std` is not used).
- `draw_init_method == 2`: the `init` column of `df_params_constraints` plus
  normal noise with standard deviation `draw_init_std * scale`, where `scale`
  is `0.01` for parameters with `abs(init) < 0.01` and `1.0` otherwise; the
  result is then clipped to the optimizer's bounds.

# Throws
- `DomainError` for any other `draw_init_method`.
"""
function _draw_init_params(
        rng::MersenneTwister,
        nkphopt::NKPHEstimationOptimizer,
        draw_init_method::Int,
        draw_init_std::Float64
    )::Vector{Float64}
    if !(draw_init_method == 1 || draw_init_method == 2)
        throw(DomainError(draw_init_method, "bad draw method"))
    end
    n_par = nkphopt.nkphsolvr.N_params

    if draw_init_method == 1
        # uniform draw, scaled by parameter bounds
        u = rand(rng, n_par)
        lb = nkphopt.nl_opt.lower_bounds[1:n_par]
        ub = nkphopt.nl_opt.upper_bounds[1:n_par]
        return lb + u .* (ub - lb)
    end

    # normal random perturbation around the configured initial values
    p_init = copy(Vector{Float64}(nkphopt.df_params_constraints[:, :init]))
    # scale std dev of draws
    # NOTE(review): every |init| >= 0.01 gets scale 1.0 (the original "> 1.0"
    # branch was a no-op); if scaling by |init| clamped to [0.01, 1.0] was
    # intended, this needs changing -- confirm
    noise_scale = [abs(p) < 0.01 ? 0.01 : 1.0 for p in p_init]
    noise = randn(rng, n_par)
    p0 = p_init + draw_init_std .* noise_scale .* noise

    # make sure parameters are within bounds
    lb = nkphopt.nl_opt.lower_bounds[1:n_par]
    ub = nkphopt.nl_opt.upper_bounds[1:n_par]
    for k in eachindex(p0)
        if p0[k] > ub[k]
            p0[k] = ub[k]
        elseif p0[k] < lb[k]
            p0[k] = lb[k]
        end
    end
    return p0
end

"""
    draw_init_params(seed, nkphopt, draw_init_method, draw_init_std;
                     draw_valid_p0=true, N_draws=100)

Return a starting vector for the model parameters.

- `draw_init_method == 0`: the `init` column of `df_params_constraints`,
  unchanged (no randomness, no validity check).
- otherwise: random draws from `_draw_init_params` using
  `MersenneTwister(seed)`. With `draw_valid_p0 = false` the first draw is
  returned. With `draw_valid_p0 = true`, up to `N_draws` draws are tried and the
  first one is returned for which the model solves and the eigenvalue
  conditions hold strictly: smallest positive Upsilon eigenvalue above
  `Upsilon_mineigval`, largest negative Upsilon eigenvalue below
  `-Upsilon_mineigval`, and (if `M_mineig_constraint`) smallest M eigenvalue
  above `M_mineigval`. If no draw qualifies a warning is printed and the last
  draw is returned.

Note that the validity check updates `nkphopt.nkphsolvr` and `nkphopt.nkphtgts`
as a side effect.

# Throws
- `ArgumentError` if a valid draw is requested with `N_draws < 1`.
"""
function draw_init_params(
        seed::Int,
        nkphopt::NKPHEstimationOptimizer,
        draw_init_method::Int,
        draw_init_std::Float64;
        draw_valid_p0::Bool=true, N_draws::Int=100
    )::Vector{Float64}
    if draw_init_method==0
        return copy(Vector{Float64}(nkphopt.df_params_constraints[:, :init]))
    end

    # set seed
    rng = MersenneTwister(seed)
    if !draw_valid_p0
        return _draw_init_params(rng, nkphopt, draw_init_method, draw_init_std)
    end
    if N_draws < 1
        throw(ArgumentError("N_draws must be at least 1, got $(N_draws)"))
    end

    # thresholds do not change between draws: look them up once
    Upsilon_mineigval = _get_optimization_options("Upsilon_mineigval", nkphopt)
    M_mineig_constraint = _get_optimization_options("M_mineig_constraint", nkphopt)
    M_mineigval = M_mineig_constraint ?
        _get_optimization_options("M_mineigval", nkphopt) : nothing

    println("drawing valid p0:")
    local p0
    for _ in 1:N_draws
        print(".")
        p0 = _draw_init_params(rng, nkphopt, draw_init_method, draw_init_std)
        # check if model solves correctly
        update_nkphsolvr_params!(p0, nkphopt.nkphsolvr)
        solved_model, m_soln = solve_nkphsolvr_continuation!(100, nkphopt.nkphsolvr;
            method=:trust_region, ftol=1e-8)
        # check eigenvalues
        x0 = vcat(p0, m_soln)
        update_nkphtgts!(x0, nkphopt.nkphtgts)
        nkphtgts = nkphopt.nkphtgts
        solved_model &= (nkphtgts.Upsilon_min_pos_eigval[1] > Upsilon_mineigval)
        # FIX: the largest negative eigenvalue must lie below -eps, matching the
        # inequality constraint in setup_nlopt!; comparing against +eps let
        # every non-positive value through
        solved_model &= (nkphtgts.Upsilon_max_neg_eigval[1] < -Upsilon_mineigval)
        if M_mineig_constraint
            solved_model &= (nkphtgts.M_min_eigval[1] > M_mineigval)
        end
        if solved_model
            println()
            return p0
        end
    end
    println("WARNING: did not draw valid p0")
    return p0
end


"""
    set_continuation_targets!(p0, nkphopt)

Solve the model at the initial parameters `p0` and initialise the continuation targets
stored in `nkphopt.nkphtgts`.

The solver is first updated with `p0` padded by zeros, then the continuation solver is
run. The stacked vector `vcat(p0, z_soln)` is pushed into the targets object, handed to
the targets-level `set_continuation_targets!` method with `valid_tol=1e-7`, and
returned. A warning is printed (no error is raised) when the solver reports failure.
"""
function set_continuation_targets!(
        p0::Vector{Float64},
        nkphopt::NKPHEstimationOptimizer,
    )::Vector{Float64}
    solver = nkphopt.nkphsolvr
    targets = nkphopt.nkphtgts

    # full-length placeholder: p0 in the leading block, remaining (endogenous)
    # entries left at zero (not used)
    x_padded = zeros(solver.N_params_all)
    x_padded[1:solver.N_params] = p0
    update_nkphsolvr!(x_padded, solver)

    # solve for the remaining block at p0
    converged, z_soln = solve_nkphsolvr_continuation!(100, solver;
        method=:trust_region, ftol=1e-8)
    converged || println("WARNING: did not solve model")

    # stack params with the solver solution and register it as the starting point
    x_start = vcat(p0, z_soln)
    update_nkphtgts!(x_start, targets)
    set_continuation_targets!(x_start, targets; valid_tol=1e-7)
    return x_start
end



"""
    estimate_model_continuation!(p0, N_steps, nkphopt)

Estimate the model starting from initial parameters `p0` with a continuation scheme.

The starting vector comes from `set_continuation_targets!(p0, nkphopt)`. Then, for each
`t` in `range(0.0, 1.0, length=N_steps)`, the continuation targets are updated with `t`
and `NLopt.optimize(nkphopt.nl_opt, x)` is run, warm-started from the previous step's
optimum. Diagnostics are printed after every step.

Returns `(loss_val, x, ret)`: the loss stored in `nkphopt.nkphtgts` after the last
update, the final stacked vector, and the last NLopt return code (`:INIT` if the loop
body never ran).
"""
function estimate_model_continuation!(
        p0::Vector{Float64},
        N_steps::Int,
        nkphopt::NKPHEstimationOptimizer,
    )::Tuple{Float64, Vector{Float64}, Symbol}
    solver = nkphopt.nkphsolvr
    moments = nkphopt.nkphmmts
    targets = nkphopt.nkphtgts
    n_params = solver.N_params
    wide_rule = "="^50
    rule = "="^20

    # solve at p0 and initialise the continuation targets; work on a private copy
    x = copy(set_continuation_targets!(p0, nkphopt))
    ret = :INIT
    loss_val = targets.loss_val[1]

    for t in range(0.0, 1.0, length=N_steps)
        # reset underlying param values to force update on each iteration
        for component in (solver, moments, targets)
            component.x_dict.vals[:] .= 0.
            component.dx_dict.vals[:] .= 0.
        end

        println(wide_rule)
        println("optimizing for t=", t)
        update_continuation_targets!(t, targets)
        optf, optx, ret = NLopt.optimize(nkphopt.nl_opt, x)
        println()
        if ret == :FORCED_STOP
            println("WARNING: forced stop")
            # pause so the user has a chance to interrupt before the next step
            sleep(2)
        end

        # report what the optimizer returned for this step
        println(rule)
        println(ret)
        println("best val/params:")
        println(optf)
        println(optx[1:n_params])
        println("AR/M params:")
        println(optx[n_params+1:end])
        println(rule)

        # push the optimum into the targets object and print diagnostics
        update_nkphtgts!(optx, targets)
        loss_val = targets.loss_val[1]
        println(rule)
        println("checking loss: ", loss_val)
        println("checking endog root")
        println(maximum(abs.(solver.root_vec)))
        println("M min eigval (real): ")
        println(targets.M_min_eigval[1])
        println("Upsilon min pos/max neg eigvals (real): ")
        println(targets.Upsilon_min_pos_eigval[1])
        println(targets.Upsilon_max_neg_eigval[1])
        println(rule)

        # TODO optionally compare optx[n_params+1:end] against a fresh
        # solve_nkphsolvr_continuation! solution (tolerance 1e-7) and abort if the
        # solver fails or the two diverge

        # warm start for the next continuation step
        x[:] = optx
    end
    return (loss_val, x, ret)
end




# saving results
# function _df2xls(df::DataFrame)
#     return collect(DataFrames.eachcol(df)), DataFrames.names(df)
# end

"""
    _df2xls_sheet!(df, xls_fpath, sheet_name)

Write `df` into the sheet `sheet_name` of the existing xlsx file at `xls_fpath`.

The file is opened in read-write mode. If no sheet with that name exists, one is added;
otherwise the existing sheet is written to. Returns `nothing`.
"""
function _df2xls_sheet!(df::DataFrame, xls_fpath::String, sheet_name::String)
    XLSX.openxlsx(xls_fpath, mode="rw") do workbook
        # position of the requested sheet among the existing ones, if any
        sheet_pos = findfirst(==(sheet_name), XLSX.sheetnames(workbook))
        worksheet = if isnothing(sheet_pos)
            # addsheet! appends the new sheet and hands it back
            XLSX.addsheet!(workbook, sheet_name)
        else
            workbook[sheet_pos]
        end
        XLSX.writetable!(worksheet, df)
    end
    return nothing
end

"""
    save_estimates(loss_val, x, est_fname, nkphopt)

Re-evaluate the model at the stacked vector `x`, run consistency checks, and write the
estimates to `ESTIM_DIR * est_fname * ".xlsx"` (an existing file is overwritten).

# Arguments
- `loss_val`: loss reported by the estimation; compared with the loss recomputed at `x`.
- `x`: stacked vector; the first `N_params` entries are the parameters, the next `J`
  entries are written to the "AR" sheet, and the rest is reshaped to `J x J` for the
  "M" sheet.
- `est_fname`: output file name without directory or extension.
- `nkphopt`: optimizer object holding the solver, targets and empirical moments.

# Output
Sheets "constraints", "parameters", "AR", "M" and "target_covs". Warnings are printed
(nothing is thrown) when the recomputed loss differs from `loss_val` by more than
`1e-10`, or when the continuation solver fails or its solution differs from the
trailing block of `x` by more than `1e-7`.

The output directory is created if it does not exist. Returns `nothing`.
"""
function save_estimates(
        loss_val::Float64,
        x::Vector{Float64},
        est_fname::String,
        nkphopt::NKPHEstimationOptimizer
    )
    # update model with specified parameters and recompute loss
    nkphsolvr = nkphopt.nkphsolvr
    nkphtgts = nkphopt.nkphtgts
    nkphemp = nkphopt.nkphemp
    J, N_params = nkphsolvr.J, nkphsolvr.N_params
    update_nkphtgts!(x, nkphtgts)
    check_loss_val = nkphtgts.loss_val[1]
    if abs(loss_val - check_loss_val) > 1e-10
        println("WARNING: check does not match loss_val")
    end

    # params/AR_hat/M matrix to dataframe
    p = x[1:N_params]
    z_vec = x[N_params+1:end]
    AR_vec = z_vec[1:J]
    m_vec = z_vec[J+1:end]
    df_params_est = DataFrame(
        param_names = [pname for pname in nkphopt.df_params_constraints[:, :param]],
        params = copy(p)
    )
    df_AR = DataFrame(copy(reshape(AR_vec, J, 1)), :auto)
    df_M = DataFrame(copy(reshape(m_vec, J, J)), :auto)

    # diagnostics are read BEFORE the continuation solver below is run on nkphsolvr
    root_max = maximum(abs.(nkphsolvr.root_vec))
    M_min_eigval = nkphtgts.M_min_eigval[1]
    Upsilon_min_pos_eigval = nkphtgts.Upsilon_min_pos_eigval[1]
    Upsilon_max_neg_eigval = nkphtgts.Upsilon_max_neg_eigval[1]

    # compare solution with continuation solver
    println("checking solution")
    solved_model, z_soln = solve_nkphsolvr_continuation!(100, nkphsolvr;
        method=:trust_region, ftol=1e-8)
    println(solved_model)
    check_soln = maximum(abs.(z_soln - z_vec))
    println(check_soln)
    if !solved_model || check_soln>1e-7
        println("WARNING: solution diverged")
    end
    println("final loss values:")
    println(loss_val)

    # check loss/constraint values (mixed numeric/bool/string/date entries)
    df_constr = DataFrame(
        constraints = ["loss_val", "check_loss_val",
            "root_max", "M_min_eigval", "Upsilon_min_pos_eigval", "Upsilon_max_neg_eigval",
            "solved_model", "check_soln",
            "model_fpath", "model_name",
            "rates_fpath", "regs_fpath", "rates_date_start", "rates_date_end"],
        values = [loss_val, check_loss_val,
            root_max, M_min_eigval, Upsilon_min_pos_eigval, Upsilon_max_neg_eigval,
            solved_model, check_soln,
            nkphopt.model_fpath, nkphopt.model_name,
            nkphemp.rates_fpath, nkphemp.regs_fpath,
            nkphemp.rates_date_start, nkphemp.rates_date_end]
    )

    # save estimates to xls file
    est_fpath = ESTIM_DIR * est_fname * ".xlsx"
    println("saving to:")
    println(est_fpath)

    # create the output directory if needed so a finished estimate is not lost
    # to a missing folder at the final write
    mkpath(dirname(est_fpath))

    XLSX.writetable(est_fpath, overwrite=true,
        "constraints" => df_constr,
        "parameters" => df_params_est,
        "AR" => df_AR,
        "M" => df_M,
        "target_covs" => nkphopt.df_target_covs,
    )
    return nothing
end
